{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "64053c0f-3582-465b-9e4c-a83da332da88",
   "metadata": {},
   "source": [
    "# Find Label Errors in Multi-Label Classification Datasets\n",
    "\n",
    "This 5-minute quickstart tutorial demonstrates how to find potential label errors in multi-label classification datasets. In such datasets, each example is labeled as belonging to one *or more* classes (unlike in *multi-class classification* where each example can only belong to one class). For a particular example in such multi-label classification data, we say each class either applies or not. We may even have some examples where *no* classes apply. Common applications include image tagging (or document tagging), where multiple tags can be appropriate for a single image (or document). For example, an image tagging application could involve the following classes: [`copyrighted`, `advertisement`, `face`, `violence`, `nsfw`]."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adaefc8b-b639-4bdf-af0d-337519e37ffc",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Quickstart\n",
    "<br/>\n",
    "    \n",
    "cleanlab finds data/label issues based on two inputs: `labels` formatted as a list of lists of integer class indices that apply to each example in your dataset, and `pred_probs` from a trained multi-label classification model (which do not need to sum to 1 since the classes are not mutually exclusive). Once you have these, run the code below to find issues in your multi-label dataset:\n",
    "\n",
    "<div  class=markdown markdown=\"1\" style=\"background:white;margin:16px\">  \n",
    "    \n",
    "```ipython3 \n",
    "from cleanlab import Datalab\n",
    "\n",
    "# Assuming your dataset has a label column named 'label'\n",
    "lab = Datalab(dataset, label_name='label', task='multilabel')\n",
    "# To detect more issue types, optionally supply `features` (numeric dataset values or model embeddings of the data)\n",
    "lab.find_issues(pred_probs=pred_probs, features=features)\n",
    "\n",
    "lab.report()\n",
    "```\n",
    "\n",
    "    \n",
    "</div>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a6261a3-6ea1-44a6-ac91-d375c8aa5535",
   "metadata": {},
   "source": [
    "## 1. Install required dependencies and get dataset\n",
    "\n",
    "You can use `pip` to install all packages required for this tutorial as follows:\n",
    "\n",
    "```ipython3\n",
    "!pip install matplotlib\n",
    "!pip install \"cleanlab[datalab]\"\n",
    "# Make sure to install the version corresponding to this tutorial\n",
    "# E.g. if viewing master branch documentation:\n",
    "#     !pip install git+https://github.com/cleanlab/cleanlab.git\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7383d024-8273-4039-bccd-aab3020d331f",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Package installation (hidden on docs.cleanlab.ai).\n",
    "# Package versions we used: matplotlib==3.5.1\n",
    "\n",
    "dependencies = [\"cleanlab\", \"matplotlib\", \"datasets\"]\n",
    "\n",
    "if \"google.colab\" in str(get_ipython()):  # Check if it's running in Google Colab\n",
    "    %pip install cleanlab  # for colab\n",
    "    cmd = ' '.join([dep for dep in dependencies if dep != \"cleanlab\"])\n",
    "    %pip install $cmd\n",
    "else:\n",
    "    dependencies_test = [dependency.split('>')[0] if '>' in dependency \n",
    "                         else dependency.split('<')[0] if '<' in dependency \n",
    "                         else dependency.split('=')[0] for dependency in dependencies]\n",
    "    missing_dependencies = []\n",
    "    for dependency in dependencies_test:\n",
    "        try:\n",
    "            __import__(dependency)\n",
    "        except ImportError:\n",
    "            missing_dependencies.append(dependency)\n",
    "\n",
    "    if len(missing_dependencies) > 0:\n",
    "        print(\"Missing required dependencies:\")\n",
    "        print(*missing_dependencies, sep=\", \")\n",
    "        print(\"\\nPlease install them before running the rest of this notebook.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf9101d8-b1a9-4305-b853-45aaf3d67a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import sklearn\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from cleanlab import Datalab\n",
    "from cleanlab.internal.multilabel_utils import int2onehot, onehot2int"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe047ed",
   "metadata": {},
   "source": [
    "Here we generate a small multi-label classification dataset for a quick demo. To see cleanlab applied to a real image tagging dataset, check out our [example](https://github.com/cleanlab/examples) notebook [\"Find Label Errors in Multi-Label Classification Data (CelebA Image Tagging)\"](https://github.com/cleanlab/examples/blob/master/multilabel_classification/image_tagging.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b283ecc-ba52-4bd7-81d8-5397966b1621",
   "metadata": {},
   "source": [
    "<details><summary>Code to generate dataset (can skip these details) **(click to expand)**</summary>\n",
    "    \n",
    "```ipython3\n",
    "# Note: This pulldown content is for docs.cleanlab.ai. If running on local Jupyter or Colab, please ignore it.\n",
    "    \n",
    "from cleanlab.benchmarking.noise_generation import (\n",
    "    generate_noise_matrix_from_trace,\n",
    "    generate_noisy_labels,\n",
    ")\n",
    "\n",
    "def make_multilabel_data(\n",
    "    means=[[-5, 3.5], [0, 2], [-3, 6]],\n",
    "    covs=[[[3, -1.5], [-1.5, 1]], [[5, -1.5], [-1.5, 1]], [[3, -1.5], [-1.5, 1]]],\n",
    "    boxes_coordinates=[[-3.5, 0, -1.5, 1.7], [-1, 3, 2, 4], [-5, 2, -3, 4], [-3, 2, -1, 4]],\n",
    "    box_multilabels=[[0, 1], [1, 2], [0, 2], [0, 1, 2]],\n",
    "    sizes=[100, 80, 100],\n",
    "    avg_trace=0.9,\n",
    "    seed=1,\n",
    "):\n",
    "    np.random.seed(seed=seed)\n",
    "    num_classes = len(means)\n",
    "    m = num_classes + len(\n",
    "        box_multilabels\n",
    "    )  # number of classes by treating each multilabel as 1 unique label\n",
    "    n = sum(sizes)\n",
    "    local_data = []\n",
    "    labels = []\n",
    "    test_data = []\n",
    "    test_labels = []\n",
    "    for i in range(0, len(means)):\n",
    "        local_data.append(np.random.multivariate_normal(mean=means[i], cov=covs[i], size=sizes[i]))\n",
    "        test_data.append(np.random.multivariate_normal(mean=means[i], cov=covs[i], size=sizes[i]))\n",
    "        test_labels += [[i]] * sizes[i]\n",
    "        labels += [[i]] * sizes[i]\n",
    "\n",
    "    def make_multi(X, Y, bx1, by1, bx2, by2, label_list):\n",
    "        ll = np.array([bx1, by1])  # lower-left\n",
    "        ur = np.array([bx2, by2])  # upper-right\n",
    "\n",
    "        inidx = np.all(np.logical_and(X.tolist() >= ll, X.tolist() <= ur), axis=1)\n",
    "        for i in range(0, len(Y)):\n",
    "            if inidx[i]:\n",
    "                Y[i] = label_list\n",
    "        return Y\n",
    "\n",
    "    X_train = np.vstack(local_data)\n",
    "    X_test = np.vstack(test_data)\n",
    "\n",
    "    for i in range(0, len(box_multilabels)):\n",
    "        bx1, by1, bx2, by2 = boxes_coordinates[i]\n",
    "        multi_label = box_multilabels[i]\n",
    "        labels = make_multi(X_train, labels, bx1, by1, bx2, by2, multi_label)\n",
    "        test_labels = make_multi(X_test, test_labels, bx1, by1, bx2, by2, multi_label)\n",
    "\n",
    "    d = {}\n",
    "    for i in labels:\n",
    "        if str(i) not in d:\n",
    "            d[str(i)] = len(d)\n",
    "    inv_d = {v: k for k, v in d.items()}\n",
    "    labels_idx = [d[str(i)] for i in labels]\n",
    "    py = np.bincount(labels_idx) / float(len(labels_idx))\n",
    "    noise_matrix = generate_noise_matrix_from_trace(\n",
    "        m,\n",
    "        trace=avg_trace * m,\n",
    "        py=py,\n",
    "        valid_noise_matrix=True,\n",
    "        seed=seed,\n",
    "    )\n",
    "    noisy_labels_idx = generate_noisy_labels(labels_idx, noise_matrix)\n",
    "    noisy_labels = [eval(inv_d[i]) for i in noisy_labels_idx]\n",
    "    return {\n",
    "        \"X_train\": X_train,\n",
    "        \"true_labels_train\": labels,\n",
    "        \"X_test\": X_test,\n",
    "        \"true_labels_test\": test_labels,\n",
    "        \"labels\": noisy_labels,\n",
    "        \"dict_unique_label\": d,\n",
    "        'labels_idx': noisy_labels_idx,\n",
    "\n",
    "    }\n",
    "\n",
    "def get_color_array(labels):\n",
    "    \"\"\"\n",
    "    This function returns a dictionary mapping multi-labels to unique colors\n",
    "    \"\"\"\n",
    "    dcolors ={'[0]': 'aa4400',\n",
    "             '[0, 2]': '55227f',\n",
    "             '[0, 1]': '55a100',\n",
    "             '[1]': '00ff00',\n",
    "             '[1, 2]': '007f7f',\n",
    "             '[0, 1, 2]': '386b55',\n",
    "             '[2]': '0000ff'}\n",
    "\n",
    "    return [\"#\"+dcolors[str(i)] for i in labels]\n",
    "\n",
    "def plot_data(data, circles, title, alpha=1.0,colors = []):\n",
    "    plt.figure(figsize=(14, 5))\n",
    "    done = set()\n",
    "    for i in range(0,len(data)):\n",
    "        lab = str(labels[i])\n",
    "        if lab in done:\n",
    "            label = \"\"\n",
    "        else:\n",
    "            label = lab\n",
    "            done.add(lab)\n",
    "        plt.scatter(data[i, 0], data[i, 1], c=colors[i], s=30,alpha=0.6, label = label)\n",
    "    for i in circles:\n",
    "        plt.plot(\n",
    "            data[i][0],\n",
    "            data[i][1],\n",
    "            \"o\",\n",
    "            markerfacecolor=\"none\",\n",
    "            markeredgecolor=\"red\",\n",
    "            markersize=14,\n",
    "            markeredgewidth=2.5,\n",
    "            alpha=alpha\n",
    "        )\n",
    "    _ = plt.title(title, fontsize=25)\n",
    "    plt.legend()\n",
    "```\n",
    "    \n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ff5c2f-bd52-44aa-b307-b2b634147c68",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "from cleanlab.benchmarking.noise_generation import (\n",
    "    generate_noise_matrix_from_trace,\n",
    "    generate_noisy_labels,\n",
    ")\n",
    "\n",
    "def make_multilabel_data(\n",
    "    means=[[-5, 3.5], [0, 2], [-3, 6]],\n",
    "    covs=[[[3, -1.5], [-1.5, 1]], [[5, -1.5], [-1.5, 1]], [[3, -1.5], [-1.5, 1]]],\n",
    "    boxes_coordinates=[[-3.5, 0, -1.5, 1.7], [-1, 3, 2, 4], [-5, 2, -3, 4], [-3, 2, -1, 4]],\n",
    "    box_multilabels=[[0, 1], [1, 2], [0, 2], [0, 1, 2]],\n",
    "    sizes=[100, 80, 100],\n",
    "    avg_trace=0.9,\n",
    "    seed=1,\n",
    "):\n",
    "    np.random.seed(seed=seed)\n",
    "    num_classes = len(means)\n",
    "    m = num_classes + len(\n",
    "        box_multilabels\n",
    "    )  # number of classes by treating each multilabel as 1 unique label\n",
    "    n = sum(sizes)\n",
    "    local_data = []\n",
    "    labels = []\n",
    "    test_data = []\n",
    "    test_labels = []\n",
    "    for i in range(0, len(means)):\n",
    "        local_data.append(np.random.multivariate_normal(mean=means[i], cov=covs[i], size=sizes[i]))\n",
    "        test_data.append(np.random.multivariate_normal(mean=means[i], cov=covs[i], size=sizes[i]))\n",
    "        test_labels += [[i]] * sizes[i]\n",
    "        labels += [[i]] * sizes[i]\n",
    "\n",
    "    def make_multi(X, Y, bx1, by1, bx2, by2, label_list):\n",
    "        ll = np.array([bx1, by1])  # lower-left\n",
    "        ur = np.array([bx2, by2])  # upper-right\n",
    "\n",
    "        inidx = np.all(np.logical_and(X.tolist() >= ll, X.tolist() <= ur), axis=1)\n",
    "        for i in range(0, len(Y)):\n",
    "            if inidx[i]:\n",
    "                Y[i] = label_list\n",
    "        return Y\n",
    "\n",
    "    X_train = np.vstack(local_data)\n",
    "    X_test = np.vstack(test_data)\n",
    "\n",
    "    for i in range(0, len(box_multilabels)):\n",
    "        bx1, by1, bx2, by2 = boxes_coordinates[i]\n",
    "        multi_label = box_multilabels[i]\n",
    "        labels = make_multi(X_train, labels, bx1, by1, bx2, by2, multi_label)\n",
    "        test_labels = make_multi(X_test, test_labels, bx1, by1, bx2, by2, multi_label)\n",
    "\n",
    "    d = {}\n",
    "    for i in labels:\n",
    "        if str(i) not in d:\n",
    "            d[str(i)] = len(d)\n",
    "    inv_d = {v: k for k, v in d.items()}\n",
    "    labels_idx = [d[str(i)] for i in labels]\n",
    "    py = np.bincount(labels_idx) / float(len(labels_idx))\n",
    "    noise_matrix = generate_noise_matrix_from_trace(\n",
    "        m,\n",
    "        trace=avg_trace * m,\n",
    "        py=py,\n",
    "        valid_noise_matrix=True,\n",
    "        seed=seed,\n",
    "    )\n",
    "    noisy_labels_idx = generate_noisy_labels(labels_idx, noise_matrix)\n",
    "    noisy_labels = [eval(inv_d[i]) for i in noisy_labels_idx]\n",
    "    return {\n",
    "        \"X_train\": X_train,\n",
    "        \"true_labels_train\": labels,\n",
    "        \"X_test\": X_test,\n",
    "        \"true_labels_test\": test_labels,\n",
    "        \"labels\": noisy_labels,\n",
    "        \"dict_unique_label\": d,\n",
    "        'labels_idx': noisy_labels_idx,\n",
    "\n",
    "    }\n",
    "\n",
    "def get_color_array(labels):\n",
    "    \"\"\"\n",
    "    This function returns a dictionary mapping multi-labels to unique colors\n",
    "    \"\"\"\n",
    "    dcolors ={'[0]': 'aa4400',\n",
    "             '[0, 2]': '55227f',\n",
    "             '[0, 1]': '55a100',\n",
    "             '[1]': '00ff00',\n",
    "             '[1, 2]': '007f7f',\n",
    "             '[0, 1, 2]': '386b55',\n",
    "             '[2]': '0000ff'}\n",
    "\n",
    "    return [\"#\"+dcolors[str(i)] for i in labels]\n",
    "\n",
    "def plot_data(data, circles, title, alpha=1.0,colors = []):\n",
    "    plt.figure(figsize=(14, 5))\n",
    "    done = set()\n",
    "    for i in range(0,len(data)):\n",
    "        lab = str(labels[i])\n",
    "        if lab in done:\n",
    "            label = \"\"\n",
    "        else:\n",
    "            label = lab\n",
    "            done.add(lab)\n",
    "        plt.scatter(data[i, 0], data[i, 1], c=colors[i], s=30,alpha=0.6, label = label)\n",
    "    for i in circles:\n",
    "        plt.plot(\n",
    "            data[i][0],\n",
    "            data[i][1],\n",
    "            \"o\",\n",
    "            markerfacecolor=\"none\",\n",
    "            markeredgecolor=\"red\",\n",
    "            markersize=14,\n",
    "            markeredgewidth=2.5,\n",
    "            alpha=alpha\n",
    "        )\n",
    "    _ = plt.title(title, fontsize=25)\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "672bfc2a",
   "metadata": {},
   "source": [
    "Some of the labels in our generated dataset purposely contain errors. The examples with label errors are circled in the plot below, which depicts the dataset. This dataset contains 3 classes, and any subset of these may be the given label for a particular example. We say an example has a label error if it is better described by an alternative subset of the classes than the given label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac65d3b-51e8-4682-b829-beab610b56d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_class = 3\n",
    "dataset = make_multilabel_data()\n",
    "labels = dataset['labels']\n",
    "true_errors = np.where(np.sum(int2onehot(dataset['true_labels_train'], K=num_class) != int2onehot(dataset['labels'], K=num_class), axis=1) >= 1)[0]\n",
    "plot_data(dataset['X_train'], circles=true_errors, title=f\"True label errors in multi-label dataset with {num_class} classes\", colors = get_color_array(labels),alpha=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "144ad4c2-49bb-4147-a743-a83ed1656a11",
   "metadata": {},
   "source": [
    "## 2. Format data, labels, and model predictions\n",
    "\n",
    "In multi-label classification, each example in the dataset is labeled as belonging to one **or more** of *K* possible classes (or none of the classes at all). To find label issues, cleanlab requires predicted class probabilities from a trained classifier. \n",
    "Here we produce out-of-sample `pred_probs` by employing cross-validation to fit a multi-label **RandomForestClassifier** model via sklearn's [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) framework. \n",
    "Make sure that the columns of your `pred_probs` are properly ordered with respect to the ordering of classes, which for Datalab is: lexicographically sorted by class name.\n",
    "`OneVsRestClassifier` offers an easy way to apply any multi-class classifier model from sklearn to multi-label classification tasks. We use it here for simplicity, but we advise against this approach in practice because it fits each class independently and so does not model dependencies between classes.\n",
    "\n",
    "To instead train a state-of-the-art PyTorch neural network for multi-label classification and produce `pred_probs` on a real image dataset (that properly account for dependencies between classes), see our [example](https://github.com/cleanlab/examples) notebook [\"Train a neural network for multi-label classification on the CelebA dataset\"](https://github.com/cleanlab/examples/blob/master/multilabel_classification/pytorch_network_training.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5fa99a9-2583-4cd0-9d40-015f698cdb23",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "y_onehot = int2onehot(labels, K=num_class)  # labels in a binary format for sklearn OneVsRestClassifier\n",
    "single_class_labels = [random.choice(i) for i in labels]  # used only for stratifying the cross-validation split \n",
    "clf = OneVsRestClassifier(RandomForestClassifier(random_state=SEED))\n",
    "pred_probs = np.zeros(shape=(len(labels), num_class))\n",
    "kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)\n",
    "\n",
    "for train_index, test_index in kf.split(X=dataset['X_train'], y=single_class_labels):\n",
    "    clf_cv = sklearn.base.clone(clf)\n",
    "    X_train_cv, X_test_cv = dataset['X_train'][train_index], dataset['X_train'][test_index]\n",
    "    y_train_cv, y_test_cv = y_onehot[train_index], y_onehot[test_index]\n",
    "    clf_cv.fit(X_train_cv, y_train_cv)\n",
    "    y_pred_cv = clf_cv.predict_proba(X_test_cv)\n",
    "    pred_probs[test_index] = y_pred_cv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41c1efab",
   "metadata": {},
   "source": [
    "`pred_probs` should be a 2D array whose rows are length-*K* vectors for **each** example in the dataset, representing the model-estimated probability that this example belongs to each class. Since one example can belong to multiple classes in multi-label classification, these probabilities need not sum to 1. For the best label error detection performance, these `pred_probs` should be out-of-sample (from a copy of the model that never saw this example during training, e.g. produced via cross-validation).\n",
    "\n",
    "`labels` should be a list of lists, whose *i*-th entry is a list of (integer) class indices that apply to the *i*-th example in the dataset. If your classes are represented as string names, you should map these to integer indices. The label for an example that belongs to none of the classes should just be an empty list `[]`.\n",
    "\n",
    "Once you have `pred_probs` and `labels` appropriately formatted, you can find/analyze label issues in any multi-label dataset via `Datalab`!\n",
    "\n",
    "Here's what these look like for the first few examples in our synthetic multi-label dataset: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac1a60df",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_to_display = 3  # increase this to see more examples\n",
    "\n",
    "print(f\"labels for first {num_to_display} examples in format expected by cleanlab:\")\n",
    "print(labels[:num_to_display])\n",
    "print(f\"pred_probs for first {num_to_display} examples in format expected by cleanlab:\")\n",
    "print(pred_probs[:num_to_display])"
   ]
  },
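  {
   "cell_type": "markdown",
   "id": "3f2a9c1e-7b64-4d0a-9e55-1c2d8a6f4b01",
   "metadata": {},
   "source": [
    "If your dataset stores each example's tags as class-name strings, the mapping to integer indices described above might look like the following minimal sketch. The `string_labels` variable and its tag names are hypothetical (borrowed from the image tagging illustration at the top of this tutorial). Sorting the class names is meant to match the lexicographic column ordering that Datalab expects for `pred_probs`.\n",
    "\n",
    "```python\n",
    "# Hypothetical string-valued multi-label annotations (one list of tags per example)\n",
    "string_labels = [['face', 'advertisement'], [], ['nsfw'], ['face']]\n",
    "\n",
    "# Sort class names so that index k corresponds to column k of pred_probs\n",
    "class_names = sorted({tag for tags in string_labels for tag in tags})\n",
    "class_to_index = {name: idx for idx, name in enumerate(class_names)}\n",
    "\n",
    "# Convert each example's tags to a sorted list of integer class indices\n",
    "int_labels = [sorted(class_to_index[tag] for tag in tags) for tags in string_labels]\n",
    "\n",
    "print(class_names)\n",
    "print(int_labels)\n",
    "```\n",
    "\n",
    "In practice you would likely build `class_names` from your full list of classes rather than from the observed labels, since a class that never appears in the annotations would otherwise be missing from the mapping."
   ]
  },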
  {
   "cell_type": "markdown",
   "id": "5a973506-c30e-4409-ac65-495537d13730",
   "metadata": {},
   "source": [
    "## 3. Use cleanlab to find label issues \n",
    "\n",
    "Based on the given `labels` and `pred_probs` from a trained model, cleanlab can quickly help us find label errors in our dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09115b6-ad44-474f-9c8a-85a459586439",
   "metadata": {},
   "outputs": [],
   "source": [
    "lab = Datalab(\n",
    "    data={\"labels\": labels},\n",
    "    label_name=\"labels\",\n",
    "    task=\"multilabel\",\n",
    ")\n",
    "\n",
    "lab.find_issues(\n",
    "    pred_probs=pred_probs,\n",
    "    issue_types={\"label\": {}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "439c003e",
   "metadata": {},
   "source": [
    "Here we retrieve the indices of the examples identified with label issues, sorted by cleanlab’s *self-confidence* label quality score, which measures the quality of individual labels. The resulting `issues` array contains the indices of the examples in your dataset that cleanlab finds most likely to be mislabeled, with the lowest-quality labels at the start."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c18dd83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_issues = lab.get_issues(\"label\")\n",
    "\n",
    "issues = label_issues.query(\"is_label_issue\").sort_values(\"label_score\").index.values\n",
    "\n",
    "print(f\"Indices of examples with label issues:\\n{issues}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6af5833",
   "metadata": {},
   "source": [
    "Let's look at the samples that cleanlab thinks are most likely to be mislabeled. You can see that cleanlab was able to identify most of `true_errors` in our small dataset (despite not having access to this variable, which you won't have in your own applications)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fffa88f6-84d7-45fe-8214-0e22079a06d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_data(dataset['X_train'], circles=issues, title=f\"Inferred label issues in multi-label dataset with {num_class} classes\", colors = get_color_array(labels), alpha = 1)"
   ]
  },
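  {
   "cell_type": "markdown",
   "id": "8c4e2b7a-1d93-4f6e-a0b8-5e7f3c9d2a10",
   "metadata": {},
   "source": [
    "Because this synthetic dataset comes with known errors, the overlap between flagged examples and true errors can also be quantified rather than only inspected visually. The sketch below uses small hypothetical index lists (`flagged` and `known_errors` are stand-ins for `issues` and `true_errors`, not the actual values from this notebook) and computes precision, recall, and the Jaccard index with plain Python sets.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-ins for `issues` (flagged by cleanlab) and `true_errors`\n",
    "flagged = [3, 7, 12, 20, 41]\n",
    "known_errors = [3, 7, 12, 33]\n",
    "\n",
    "flagged_set, known_set = set(flagged), set(known_errors)\n",
    "overlap = flagged_set & known_set\n",
    "\n",
    "precision = len(overlap) / len(flagged_set)  # fraction of flagged examples that are real errors\n",
    "recall = len(overlap) / len(known_set)  # fraction of real errors that were flagged\n",
    "jaccard = len(overlap) / len(flagged_set | known_set)  # overlap relative to the union\n",
    "\n",
    "print(f'precision={precision:.2f}, recall={recall:.2f}, jaccard={jaccard:.2f}')\n",
    "```\n",
    "\n",
    "Such a comparison is only possible in benchmarks like this one, where the true errors are known."
   ]
  },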
  {
   "cell_type": "markdown",
   "id": "32465521",
   "metadata": {},
   "source": [
    "### Label quality scores\n",
    "\n",
    "The above code identifies which examples have label issues and sorts them by their label quality score. We can also take a look at this label quality score for each example in the dataset, which estimates our confidence that this example has been correctly labeled. These scores range between 0 and 1 with smaller values indicating examples whose label seems more suspect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1198575",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = label_issues[\"label_score\"].values\n",
    "\n",
    "print(f\"Label quality scores of the first 10 examples in dataset:\\n{scores[:10]}\")"
   ]
  },
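  {
   "cell_type": "markdown",
   "id": "b6d1f0a3-92c5-4e18-8f7d-0a4c6e2b9d37",
   "metadata": {},
   "source": [
    "To build intuition for these scores, here is a simplified sketch of a self-confidence style score for multi-label data. For each class, the per-class score is the probability the model assigns to the given annotation (`p` if the class was tagged, `1 - p` otherwise). Taking the minimum over classes is an assumption made purely for illustration; cleanlab may aggregate per-class scores differently, so these numbers are not expected to match `label_score`. The toy inputs are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy inputs: 3 examples, 3 classes\n",
    "toy_labels = [[0], [0, 1], []]\n",
    "toy_pred_probs = np.array([\n",
    "    [0.9, 0.1, 0.2],\n",
    "    [0.8, 0.3, 0.1],\n",
    "    [0.1, 0.7, 0.2],\n",
    "])\n",
    "\n",
    "# Binary matrix: entry (i, k) is 1 if class k is in the given label of example i\n",
    "given = np.zeros(toy_pred_probs.shape, dtype=int)\n",
    "for i, class_indices in enumerate(toy_labels):\n",
    "    for k in class_indices:\n",
    "        given[i, k] = 1\n",
    "\n",
    "# Per-class self-confidence: probability the model assigns to the given annotation\n",
    "per_class_scores = np.where(given == 1, toy_pred_probs, 1 - toy_pred_probs)\n",
    "\n",
    "# Simplified aggregation (an assumption for this sketch): the worst class decides the score\n",
    "toy_scores = per_class_scores.min(axis=1)\n",
    "print(toy_scores)\n",
    "```\n",
    "\n",
    "In this toy example, the second example scores low because the model assigns only 0.3 probability to class 1 even though it was tagged, and the third scores low because the model favors class 1 (0.7) although no classes were tagged."
   ]
  },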
  {
   "cell_type": "markdown",
   "id": "d65af827-aeda-4b6b-9ae7-b1f0b84700d6",
   "metadata": {},
   "source": [
    "### Data issues beyond mislabeling (outliers, duplicates, drift, ...)\n",
    "\n",
    "While this tutorial focused on label issues, cleanlab's `Datalab` object can automatically detect many other types of issues in your dataset (outliers, near duplicates, drift, etc).\n",
    "Simply remove the `issue_types` argument from the call to `Datalab.find_issues()` above and `Datalab` will more comprehensively audit your dataset.\n",
    "Refer to our [Datalab quickstart tutorial](./datalab/datalab_quickstart.html) to learn how to interpret the results (the interpretation remains mostly the same across different types of ML tasks)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d65af827-aeda-4b6b-9ae7-b1f0b84700d5",
   "metadata": {},
   "source": [
    "### How to format labels given as a one-hot (multi-hot) binary matrix?\n",
    "\n",
    "For multi-label classification, cleanlab expects labels to be formatted as a list of lists, where each entry is an integer corresponding to a particular class. Here are some functions you can use to easily convert labels between this format and a binary matrix format commonly used to train multi-label classification models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49161b19-7625-4fb7-add9-607d91a7eca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_binary_format = int2onehot(labels, K=num_class)\n",
    "labels_list_format = onehot2int(labels_binary_format)"
   ]
  },
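  {
   "cell_type": "markdown",
   "id": "e2a7c5d9-4f01-4b3c-8d6a-7c1b0f9e3a58",
   "metadata": {},
   "source": [
    "To make the two formats concrete, the sketch below performs the same round trip on a toy example using plain NumPy. The helper names `lists_to_binary` and `binary_to_lists` are hypothetical and are not cleanlab functions; they only illustrate what such a conversion does.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def lists_to_binary(label_lists, num_classes):\n",
    "    # Row i has a 1 in column k if class k applies to example i\n",
    "    binary = np.zeros((len(label_lists), num_classes), dtype=int)\n",
    "    for i, class_indices in enumerate(label_lists):\n",
    "        for k in class_indices:\n",
    "            binary[i, k] = 1\n",
    "    return binary\n",
    "\n",
    "def binary_to_lists(binary):\n",
    "    # Collect the column indices of the nonzero entries in each row\n",
    "    return [[int(k) for k in np.flatnonzero(row)] for row in binary]\n",
    "\n",
    "toy_labels = [[0, 2], [1], []]\n",
    "toy_binary = lists_to_binary(toy_labels, num_classes=3)\n",
    "print(toy_binary)\n",
    "print(binary_to_lists(toy_binary))\n",
    "```\n",
    "\n",
    "Note that an example with no applicable classes becomes an all-zero row in the binary matrix and an empty list in the list format."
   ]
  },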
  {
   "cell_type": "markdown",
   "id": "a58200c8",
   "metadata": {},
   "source": [
    "### Estimate label issues without Datalab \n",
    "If you prefer to directly run the same lower-level mathematical functions Datalab uses to detect label issues, you can do so outside of Datalab via the methods in the `cleanlab.multilabel_classification` module, such as [multilabel_classification.filter.find_label_issues](../cleanlab/multilabel_classification/filter.html#cleanlab.multilabel_classification.filter.find_label_issues) and [multilabel_classification.rank.get_label_quality_scores](../cleanlab/multilabel_classification/rank.html#cleanlab.multilabel_classification.rank.get_label_quality_scores).\n",
    "\n",
    "### Application to Real Data \n",
    "\n",
    "To see cleanlab applied to a real image tagging dataset, check out our [example](https://github.com/cleanlab/examples) notebook [\"Find Label Errors in Multi-Label Classification Data (CelebA Image Tagging)\"](https://github.com/cleanlab/examples/blob/master/multilabel_classification/image_tagging.ipynb). That example also demonstrates how to use a state-of-the-art PyTorch neural network for multi-label classification with image data."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1bd9f83",
   "metadata": {},
   "source": [
    "\n",
    "## Spending too much time on data quality?\n",
    "\n",
    "Using this open-source package effectively can require significant ML expertise and experimentation, plus handling detected data issues can be cumbersome.\n",
    "\n",
    "That’s why we built [Cleanlab Studio](https://cleanlab.ai/blog/data-centric-ai/) -- an automated platform to find **and fix** issues in your dataset, 100x faster and more accurately.  Cleanlab Studio automatically runs optimized data quality algorithms from this package on top of cutting-edge AutoML & Foundation models fit to your data, and helps you fix detected issues via a smart data correction interface. [Try it](https://cleanlab.ai/) for free!\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"https://raw.githubusercontent.com/cleanlab/assets/master/cleanlab/ml-with-cleanlab-studio.png\" alt=\"The modern AI pipeline automated with Cleanlab Studio\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1a2c008",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Note: This cell is only for docs.cleanlab.ai. If running on local Jupyter or Colab, please ignore it.\n",
    "\n",
    "A = set(issues)\n",
    "B = set(true_errors)\n",
    "jaccard = len(A.intersection(B)) / len(A.union(B))\n",
    "if not jaccard > 0.7:\n",
    "    raise Exception(\"issues does not overlap much with the true errors\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
